import torch 
from PIL import Image
import numpy as np
import torchvision.transforms as transforms
from torch.utils.data import Dataset
import pandas as pd
import os
import glob
from sklearn.metrics import roc_auc_score, precision_recall_curve, auc



def pAUC_two_metric(target, pred, max_fpr):
  """Compute a two-way partial AUC restricted to the hardest samples.

  From the positives (entries of ``target`` equal to 1) the ``max_fpr``
  fraction with the lowest scores is kept; from the negatives (every other
  entry) the ``max_fpr`` fraction with the highest scores is kept. At least
  one sample of each class is always kept. ``roc_auc_score`` is then
  evaluated on the kept samples only.

  Args:
    target: array-like of labels; flattened before use.
    pred: array-like of scores with the same number of elements as
      ``target``; flattened before use.
    max_fpr: fraction of each class to keep. Values that would select more
      samples than a class contains are clamped to the whole class.

  Returns:
    The value returned by ``roc_auc_score`` on the selected subset.

  Raises:
    ValueError: if ``target`` contains no positive or no negative sample.
  """
  target = np.asarray(target).reshape(-1)
  pred = np.asarray(pred).reshape(-1)
  idx_pos = np.where(target == 1)[0]
  idx_neg = np.where(target != 1)[0]
  if len(idx_pos) == 0 or len(idx_neg) == 0:
    raise ValueError(
        'pAUC_two_metric needs at least one positive and one negative sample')

  def _hardest(indices, sort_key):
    # Keep between 1 and len(indices) samples so that argpartition never
    # receives an out-of-range kth.
    count = int(round(len(indices) * max_fpr))
    count = min(max(count, 1), len(indices))
    if count == len(indices):
      return indices
    # The first `count` entries of argpartition are the `count` smallest keys.
    return indices[np.argpartition(sort_key, count)[:count]]

  keep = np.concatenate((
      _hardest(idx_pos, pred[idx_pos]),    # lowest-scored positives
      _hardest(idx_neg, -pred[idx_neg]),   # highest-scored negatives
  ))
  return roc_auc_score(target[keep], pred[keep])


def set_all_seeds(SEED):
    """Seed the torch and NumPy RNGs and put cuDNN in deterministic mode.

    Args:
        SEED: seed value passed to ``torch.manual_seed`` and
            ``np.random.seed``.
    """
    cudnn = torch.backends.cudnn
    # Reproducibility: disable cuDNN autotuning and request deterministic
    # kernels.
    cudnn.benchmark = False
    cudnn.deterministic = True
    for seed_fn in (torch.manual_seed, np.random.seed):
        seed_fn(SEED)

def collate_fn(list_items):
    """Regroup a batch of 3-element samples into three parallel lists.

    Args:
        list_items: iterable of ``(x, y, id)`` triples.

    Returns:
        A tuple ``(xs, ys, ids)`` of lists, each in the input order. No
        stacking or type conversion is applied to the elements.
    """
    # Unpacking in the comprehension enforces exactly three fields per item.
    rows = [(sample, label, index) for sample, label, index in list_items]
    samples = [row[0] for row in rows]
    labels = [row[1] for row in rows]
    indices = [row[2] for row in rows]
    return samples, labels, indices



class TabularDataset(Dataset):
    """Dataset yielding ``(X[idx], Y[idx], idx)`` from two parallel containers.

    ``targets`` is computed once at construction by concatenating the entries
    of ``Y`` along axis 0.
    """

    def __init__(self, X, Y):
        self.X, self.Y = X, Y
        self.targets = self.get_labels()

    def __len__(self):
        """Return the number of entries in ``Y``."""
        n_items = len(self.Y)
        return n_items

    def __getitem__(self, idx):
        """Return the ``idx``-th sample, its label and ``idx`` as an int."""
        sample = self.X[idx]
        label = self.Y[idx]
        return sample, label, int(idx)

    def get_labels(self):
        """Return all entries of ``Y`` concatenated along the first axis."""
        labels = np.concatenate(self.Y, axis=0)
        return labels


class OCTDataset(Dataset):
    """Image dataset yielding ``(image_tensor, target, idx)`` triples.

    Images are cast to ``uint8`` at construction and converted with
    ``transforms.ToTensor()`` on access. In ``'train'`` mode a random
    horizontal flip is additionally applied when augmentation is enabled.

    Args:
        images: array of images; must support ``astype`` and integer indexing.
        targets: per-image targets, indexable in parallel with ``images``.
        mode: ``'train'`` selects the training transform; any other value
            selects the test transform.
        image_size: accepted for interface compatibility; not used here.
        augmentation: whether the training transform includes
            ``RandomHorizontalFlip``. When left as ``None`` the value is read
            from a module-level ``FLAGS.augmentation`` if such an object
            exists, and defaults to ``False`` otherwise.
    """

    def __init__(self, images, targets, mode='train', image_size=64, augmentation=None):
        self.images = images.astype(np.uint8)
        self.targets = targets
        self.mode = mode
        if augmentation is None:
            # This module does not define FLAGS itself; only honour it when a
            # caller has injected it, instead of failing with a NameError.
            flags = globals().get('FLAGS')
            augmentation = bool(getattr(flags, 'augmentation', False))
        train_steps = [transforms.ToTensor()]
        if augmentation:
            train_steps.append(transforms.RandomHorizontalFlip())
        self.transform_train = transforms.Compose(train_steps)
        self.transform_test = transforms.Compose([transforms.ToTensor()])

    def __len__(self):
        """Return the number of images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return the transformed image, its target and ``idx`` as an int."""
        image = self.images[idx]
        target = self.targets[idx]
        if self.mode == 'train':
            image = self.transform_train(image)
        else:
            image = self.transform_test(image)
        return image, target, int(idx)

    def get_labels(self):
        """Return all targets as a flat NumPy array."""
        return np.array(self.targets).reshape(-1)



def random_sample_y(list_X, model, batch_size=1, mode='plain', tau=0.1, device=None):
    """Pool model outputs over a random subset of one bag's instances.

    Up to ``batch_size`` instances are drawn uniformly without replacement
    from the bag, passed through ``model`` as float tensors, and the outputs
    are flattened and pooled according to ``mode``.

    Args:
        list_X: the bag; a list of NumPy arrays (concatenated along axis 0),
            a single NumPy array, or a tensor. The first axis indexes
            instances.
        model: callable applied to the sampled instances. For ``mode='att'``
            it must return a ``(predictions, weights)`` pair.
        batch_size: maximum number of instances to sample.
        mode: ``'plain'`` (mean of outputs), ``'max'`` (maximum output),
            ``'exp'`` (mean of ``exp(output / tau)``) or ``'att'``.
        tau: temperature, used by ``mode='exp'`` only.
        device: device the bag is moved to; defaults to CUDA when available,
            otherwise CPU.

    Returns:
        A ``(1, 1)`` tensor for ``'plain'``, ``'max'`` and ``'exp'``. For
        ``'att'`` a pair of ``(1, 1)`` tensors: the mean of
        ``predictions * weights`` and the mean of ``weights``.

    Raises:
        ValueError: if ``mode`` is not one of the supported values.
    """
    if mode not in ('plain', 'max', 'exp', 'att'):
        # Previously an unknown mode fell through and returned None.
        raise ValueError(
            "random_sample_y: unsupported mode %r "
            "(expected 'plain', 'max', 'exp' or 'att')" % (mode,))
    if not device:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if isinstance(list_X, list):
        X = torch.from_numpy(np.concatenate(list_X, axis=0)).to(device)
    else:  # a NumPy array or a tensor
        if isinstance(list_X, np.ndarray):
            list_X = torch.from_numpy(list_X)
        X = list_X.to(device)

    bag_size = X.shape[0]
    sample_size = min(bag_size, batch_size)
    # Uniform weights: every instance is equally likely to be drawn.
    ids = torch.multinomial(torch.ones(bag_size), sample_size, replacement=False)
    X = X[ids, ...]

    output = model(X.float())
    if mode == 'att':
        y_pred_bag, weights_bag = output
        sn_bag = y_pred_bag * weights_bag
        sn = torch.mean(sn_bag.view([1, -1]), dim=1, keepdim=True)
        sd = torch.mean(weights_bag.view([1, -1]), dim=1, keepdim=True)
        return sn, sd
    if mode == 'max':
        return torch.max(output.view([1, -1]), dim=1, keepdim=True).values
    if mode == 'exp':
        output = torch.exp(output / tau)
    return torch.mean(output.view([1, -1]), dim=1, keepdim=True)


def instance_max_y(list_X, model, idx=None, mode='max', tau=0.1, device=None):
    """Score a single bag, or a single instance of it, with ``model``.

    Handles one bag at a time; the bag may be a list of NumPy arrays, a NumPy
    array, or a tensor.

    With ``idx`` left as ``None`` the whole bag is passed through ``model``
    and the flattened outputs are pooled according to ``mode``:

    * ``'max'``: maximum output.
    * ``'mean'``: mean output.
    * ``'softmax'``: ``tau * log(mean(exp(output / tau)))``, evaluated with
      ``logsumexp`` so large ``output / tau`` values do not overflow.
    * ``'att'``: ``model`` returns ``(predictions, weights)``; the result is
      the sum of predictions times the L1-normalised weights.

    With ``idx`` given, only element ``idx`` of the bag is passed through
    ``model`` and the raw model output is returned; ``mode`` and ``tau`` are
    not used. An element taken from an array or tensor gets a leading axis
    added first; an element taken from a list is used as is.

    Args:
        list_X: the bag.
        model: callable applied to the float tensor.
        idx: optional index of a single element to score.
        mode: pooling mode, see above.
        tau: temperature for ``mode='softmax'``.
        device: device the data is moved to; defaults to CUDA when available,
            otherwise CPU.

    Returns:
        A ``(1, 1)`` tensor when pooling, otherwise whatever ``model`` returns.

    Raises:
        ValueError: if ``idx`` is ``None`` and ``mode`` is not supported.
    """
    if not device:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    if idx is not None:
        if isinstance(list_X, list):
            X = torch.from_numpy(list_X[idx]).to(device)
        else:  # a NumPy array or a tensor
            if isinstance(list_X, np.ndarray):
                list_X = torch.from_numpy(list_X)
            X = torch.unsqueeze(list_X[idx].to(device), dim=0)
        return model(X.float())

    if mode not in ('max', 'mean', 'softmax', 'att'):
        # Previously an unknown mode surfaced as an UnboundLocalError.
        raise ValueError(
            "instance_max_y: unsupported mode %r "
            "(expected 'max', 'mean', 'softmax' or 'att')" % (mode,))

    if isinstance(list_X, list):
        X = torch.from_numpy(np.concatenate(list_X, axis=0)).to(device)
    else:  # a NumPy array or a tensor
        if isinstance(list_X, np.ndarray):
            list_X = torch.from_numpy(list_X)
        X = list_X.to(device)

    y_pred_bag = model(X.float())
    if mode == 'att':
        y_pred_bag, w_pred_bag = y_pred_bag[0], y_pred_bag[1]
        weights = torch.nn.functional.normalize(w_pred_bag.view([1, -1]), p=1.0, dim=-1)
        return torch.sum(y_pred_bag.view([1, -1]) * weights, dim=1, keepdim=True)

    scores = y_pred_bag.view([1, -1])
    if mode == 'max':
        return torch.max(scores, dim=1, keepdim=True).values
    if mode == 'mean':
        return torch.mean(scores, dim=1, keepdim=True)
    # 'softmax': log(mean(exp(z))) == logsumexp(z) - log(n), which stays
    # finite where exp(z) itself would overflow.
    log_n = float(np.log(scores.shape[1]))
    return tau * (torch.logsumexp(scores / tau, dim=1, keepdim=True) - log_n)


def evaluate_auc(dataloader, model, mode='max', tau=0.1, debug=False):
  """Compute the ROC AUC of bag-level predictions over a dataloader.

  Every batch is expected to unpack into ``(bags, labels, ids)``. Each bag
  ``bags[i]`` for ``i < len(ids)`` is scored with ``instance_max_y`` using the
  given ``mode`` and ``tau``.

  Args:
    dataloader: iterable of ``(bags, labels, ids)`` batches.
    model: model forwarded to ``instance_max_y``.
    mode: pooling mode forwarded to ``instance_max_y``.
    tau: temperature forwarded to ``instance_max_y``.
    debug: when truthy, print the first 50 rows of labels next to
      predictions (deprecated; needs both to be 2-D).

  Returns:
    The value of ``roc_auc_score(labels, predictions)`` over all batches.
  """
  test_pred = []
  test_true = []
  # Inference only: skip autograd bookkeeping instead of building a graph for
  # every bag and detaching afterwards.
  with torch.no_grad():
    for test_data_bags, test_labels, ids in dataloader:
      bag_preds = [
          instance_max_y(test_data_bags[i], model, mode=mode, tau=tau)
          for i in range(len(ids))
      ]
      y_pred = torch.cat(bag_preds, dim=0)
      test_pred.append(y_pred.cpu().detach().numpy())
      test_true.append(test_labels)
  test_true = np.concatenate(test_true, axis=0)
  test_pred = np.concatenate(test_pred, axis=0)
  if debug:  # for debug, it is deprecated
    print(np.concatenate([test_true, test_pred], axis=1)[:50])
  return roc_auc_score(test_true, test_pred)


def evaluate_tpauc(dataloader, model, mode='max', tau=0.1, fprs=(0.1, 0.3, 0.5, 0.7, 0.9), debug=False):
  """Compute two-way partial AUCs of bag-level predictions at several rates.

  Every batch is expected to unpack into ``(bags, labels, ids)``. Each bag
  ``bags[i]`` for ``i < len(ids)`` is scored with ``instance_max_y`` using the
  given ``mode`` and ``tau``; ``pAUC_two_metric`` is then evaluated once per
  entry of ``fprs`` on the collected labels and predictions.

  Args:
    dataloader: iterable of ``(bags, labels, ids)`` batches.
    model: model forwarded to ``instance_max_y``.
    mode: pooling mode forwarded to ``instance_max_y``.
    tau: temperature forwarded to ``instance_max_y``.
    fprs: sequence of rates, each passed as ``max_fpr`` to
      ``pAUC_two_metric``. The default is a tuple so that it cannot be
      mutated between calls.
    debug: when truthy, print the first 50 rows of labels next to
      predictions (deprecated; needs both to be 2-D).

  Returns:
    A 1-D NumPy array with one partial-AUC value per entry of ``fprs``.
  """
  test_pred = []
  test_true = []
  # Inference only: skip autograd bookkeeping instead of building a graph for
  # every bag and detaching afterwards.
  with torch.no_grad():
    for test_data_bags, test_labels, ids in dataloader:
      bag_preds = [
          instance_max_y(test_data_bags[i], model, mode=mode, tau=tau)
          for i in range(len(ids))
      ]
      y_pred = torch.cat(bag_preds, dim=0)
      test_pred.append(y_pred.cpu().detach().numpy())
      test_true.append(test_labels)
  test_true = np.concatenate(test_true, axis=0)
  test_pred = np.concatenate(test_pred, axis=0)
  if debug:  # for debug, it is deprecated
    print(np.concatenate([test_true, test_pred], axis=1)[:50])
  return np.array([pAUC_two_metric(test_true, test_pred, fpr) for fpr in fprs])


